{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {},
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/W1D4_MachineLearning/W1D4_Tutorial1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a> &nbsp; <a href=\"https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/NeoNeuron/professional-workshop-3/master/tutorials/W4_ModelFitting/W4_Tutorial2.ipynb\" target=\"_parent\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" alt=\"Open in Kaggle\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "# Model Fitting: Generalized Linear Models (GLMs)\n",
    "\n",
    "__Content creators:__ Pierre-Etienne H. Fiquet, Ari Benjamin, Jakob Macke\n",
    "\n",
    "__Content modified:__ Kai Chen\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "This is tutorial about Generalized Linear Models (GLMs), which are a fundamental framework for supervised learning.\n",
    "\n",
    "In this tutorial, the objective is to model a retinal ganglion cell spike train by fitting a temporal receptive field. First with a Linear-Gaussian GLM (also known as ordinary least-squares regression model) and then with a Poisson GLM (aka \"Linear-Nonlinear-Poisson\" model). \n",
    "<!-- In the next tutorial, we’ll extend to a special case of GLMs, logistic regression, and learn how to ensure good model performance. -->\n",
    "\n",
    "This tutorial is designed to run with retinal ganglion cell spike train data from [Uzzell & Chichilnisky 2004](https://journals.physiology.org/doi/full/10.1152/jn.01171.2003?url_ver=Z39.88-2003&rfr_id=ori:rid:crossref.org&rfr_dat=cr_pub%20%200pubmed).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "# Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import minimize\n",
    "from scipy.io import loadmat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "#@title Figure settings\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "nma_style = {\n",
    "    'figure.figsize' : (8, 6),\n",
    "    'figure.autolayout' : True,\n",
    "    'font.size' : 15,\n",
    "    'xtick.labelsize' : 'small',\n",
    "    'ytick.labelsize' : 'small',\n",
    "    'legend.fontsize' : 'small',\n",
    "    'axes.spines.top' : False,\n",
    "    'axes.spines.right' : False,\n",
    "    'xtick.major.size' : 5,\n",
    "    'ytick.major.size' : 5,\n",
    "}\n",
    "for key, value in nma_style.items():\n",
    "    plt.rcParams[key] = value\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "#@title Helper functions\n",
    "\n",
    "def plot_stim_and_spikes(stim, spikes, dt, nt=120):\n",
    "  \"\"\"Show time series of stim intensity and spike counts.\n",
    "\n",
    "  Args:\n",
    "    stim (1D array): vector of stimulus intensities\n",
    "    spikes (1D array): vector of spike counts\n",
    "    dt (number): duration of each time step\n",
    "    nt (number): number of time steps to plot\n",
    "\n",
    "  \"\"\"\n",
    "  timepoints = np.arange(120)\n",
    "  time = timepoints * dt\n",
    "\n",
    "  f, (ax_stim, ax_spikes) = plt.subplots(\n",
    "    nrows=2, sharex=True, figsize=(8, 5),\n",
    "  )\n",
    "  ax_stim.plot(time, stim[timepoints])\n",
    "  ax_stim.set_ylabel('Stimulus intensity')\n",
    "\n",
    "  ax_spikes.plot(time, spikes[timepoints])\n",
    "  ax_spikes.set_xlabel('Time (s)')\n",
    "  ax_spikes.set_ylabel('Number of spikes')\n",
    "\n",
    "  f.tight_layout()\n",
    "\n",
    "\n",
    "def plot_glm_matrices(X, y, nt=50):\n",
    "  \"\"\"Show X and Y as heatmaps.\n",
    "\n",
    "  Args:\n",
    "    X (2D array): Design matrix.\n",
    "    y (1D or 2D array): Target vector.\n",
    "\n",
    "  \"\"\"\n",
    "  from matplotlib.colors import BoundaryNorm\n",
    "  from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "  Y = np.c_[y]  # Ensure Y is 2D and skinny\n",
    "\n",
    "  f, (ax_x, ax_y) = plt.subplots(\n",
    "    ncols=2,\n",
    "    figsize=(6, 8),\n",
    "    sharey=True,\n",
    "    gridspec_kw=dict(width_ratios=(5, 1)),\n",
    "  )\n",
    "  norm = BoundaryNorm([-1, -.2, .2, 1], 256)\n",
    "  imx = ax_x.pcolormesh(X[:nt], cmap=\"coolwarm\", norm=norm)\n",
    "\n",
    "  ax_x.set(\n",
    "    title=\"X\\n(lagged stimulus)\",\n",
    "    xlabel=\"Time lag (time bins)\",\n",
    "    xticks=[4, 14, 24],\n",
    "    xticklabels=['-20', '-10', '0'],\n",
    "    ylabel=\"Time point (time bins)\",\n",
    "  )\n",
    "  plt.setp(ax_x.spines.values(), visible=True)\n",
    "\n",
    "  divx = make_axes_locatable(ax_x)\n",
    "  caxx = divx.append_axes(\"right\", size=\"5%\", pad=0.1)\n",
    "  cbarx = f.colorbar(imx, cax=caxx)\n",
    "  cbarx.set_ticks([-.6, 0, .6])\n",
    "  cbarx.set_ticklabels(np.sort(np.unique(X)))\n",
    "\n",
    "  norm = BoundaryNorm(np.arange(y.max() + 1), 256)\n",
    "  imy = ax_y.pcolormesh(Y[:nt], cmap=\"magma\", norm=norm)\n",
    "  ax_y.set(\n",
    "    title=\"Y\\n(spike count)\",\n",
    "    xticks=[]\n",
    "  )\n",
    "  ax_y.invert_yaxis()\n",
    "  plt.setp(ax_y.spines.values(), visible=True)\n",
    "\n",
    "  divy = make_axes_locatable(ax_y)\n",
    "  caxy = divy.append_axes(\"right\", size=\"30%\", pad=0.1)\n",
    "  cbary = f.colorbar(imy, cax=caxy)\n",
    "  cbary.set_ticks(np.arange(y.max()) + .5)\n",
    "  cbary.set_ticklabels(np.arange(y.max()))\n",
    "\n",
    "def plot_spike_filter(theta, dt, **kws):\n",
    "  \"\"\"Plot estimated weights based on time lag model.\n",
    "\n",
    "  Args:\n",
    "    theta (1D array): Filter weights, not including DC term.\n",
    "    dt (number): Duration of each time bin.\n",
    "    kws: Pass additional keyword arguments to plot()\n",
    "\n",
    "  \"\"\"\n",
    "  d = len(theta)\n",
    "  t = np.arange(-d + 1, 1) * dt\n",
    "\n",
    "  ax = plt.gca()\n",
    "  ax.plot(t, theta, marker=\"o\", **kws)\n",
    "  ax.axhline(0, color=\".2\", linestyle=\"--\", zorder=1)\n",
    "  ax.set(\n",
    "    xlabel=\"Time before spike (s)\",\n",
    "    ylabel=\"Filter weight\",\n",
    "  )\n",
    "\n",
    "\n",
    "def plot_spikes_with_prediction(\n",
    "    spikes, predicted_spikes, dt, nt=50, t0=120, **kws):\n",
    "  \"\"\"Plot actual and predicted spike counts.\n",
    "\n",
    "  Args:\n",
    "    spikes (1D array): Vector of actual spike counts\n",
    "    predicted_spikes (1D array): Vector of predicted spike counts\n",
    "    dt (number): Duration of each time bin.\n",
    "    nt (number): Number of time bins to plot\n",
    "    t0 (number): Index of first time bin to plot.\n",
    "    kws: Pass additional keyword arguments to plot()\n",
    "\n",
    "  \"\"\"\n",
    "  t = np.arange(t0, t0 + nt) * dt\n",
    "\n",
    "  f, ax = plt.subplots()\n",
    "  lines = ax.stem(t, spikes[:nt], use_line_collection=True)\n",
    "  plt.setp(lines, color=\".5\")\n",
    "  lines[-1].set_zorder(1)\n",
    "  kws.setdefault(\"linewidth\", 3)\n",
    "  yhat, = ax.plot(t, predicted_spikes[:nt], **kws)\n",
    "  ax.set(\n",
    "      xlabel=\"Time (s)\",\n",
    "      ylabel=\"Spikes\",\n",
    "  )\n",
    "  ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))\n",
    "  ax.legend([lines[0], yhat], [\"Spikes\", \"Predicted\"])\n",
    "\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "#@title Data retrieval and loading\n",
    "import os\n",
    "import hashlib\n",
    "import requests\n",
    "\n",
    "fname = \"RGCdata.mat\"\n",
    "url = \"https://osf.io/mzujs/download\"\n",
    "expected_md5 = \"1b2977453020bce5319f2608c94d38d0\"\n",
    "\n",
    "if not os.path.isfile(fname):\n",
    "  try:\n",
    "    r = requests.get(url)\n",
    "  except requests.ConnectionError:\n",
    "    print(\"!!! Failed to download data !!!\")\n",
    "  else:\n",
    "    if r.status_code != requests.codes.ok:\n",
    "      print(\"!!! Failed to download data !!!\")\n",
    "    elif hashlib.md5(r.content).hexdigest() != expected_md5:\n",
    "      print(\"!!! Data download appears corrupted !!!\")\n",
    "    else:\n",
    "      with open(fname, \"wb\") as fid:\n",
    "        fid.write(r.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "-----\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "# Section 1: Linear-Gaussian GLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "## Section 1.1: Load retinal ganglion cell activity data\n",
    "\n",
    "In this exercise we use data from an experiment that presented a screen which randomly alternated between two luminance values and recorded responses from retinal ganglion cell (RGC), a type of neuron in the retina in the back of the eye. This kind of visual stimulus is called a \"full-field flicker\", and it was presented at ~120Hz (ie. the stimulus presented on the screen was refreshed about every 8ms). These same time bins were used to count the number of spikes emitted by each neuron.\n",
    "\n",
    "The file `RGCdata.mat` contains three variablies:\n",
    "\n",
    "- `Stim`, the stimulus intensity at each time point. It is an array with shape $T \\times 1$, where $T=144051$.\n",
    "\n",
    "-  `SpCounts`, the binned spike counts for 2 ON cells, and 2 OFF cells. It is a $144051 \\times 4$ array, and each column has counts for a different cell.\n",
    "\n",
    "- `dtStim`, the size of a single time bin (in seconds), which is needed for computing model output in units of spikes / s. The stimulus frame rate is given by `1 / dtStim`.\n",
    "\n",
    "Because these data were saved in MATLAB, where everything is a matrix, we will also process the variables to more Pythonic representations (1D arrays or scalars, where appropriate) as we load the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "data = loadmat('RGCdata.mat')  # loadmat is a function in scipy.io\n",
    "dt_stim = data['dtStim'].item()  # .item extracts a scalar value\n",
    "\n",
    "# Extract the stimulus intensity\n",
    "stim = data['Stim'].squeeze()  # .squeeze removes dimensions with 1 element\n",
    "\n",
    "# Extract the spike counts for one cell\n",
    "cellnum = 2\n",
    "spikes = data['SpCounts'][:, cellnum]\n",
    "\n",
    "# Don't use all of the timepoints in the dataset, for speed\n",
    "keep_timepoints = 20000\n",
    "stim = stim[:keep_timepoints]\n",
    "spikes = spikes[:keep_timepoints]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "Use the `plot_stim_and_spikes` helper function to visualize the changes in stimulus intensities and spike counts over time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "plot_stim_and_spikes(stim, spikes, dt_stim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "### Exercise 1: Create design matrix\n",
    "\n",
    "Our goal is to predict the cell's activity from the stimulus intensities preceding it. That will help us understand how RGCs process information over time. To do so, we first need to create the *design matrix* for this model, which organizes the stimulus intensities in matrix form such that the $i$th row has the stimulus frames preceding timepoint $i$.\n",
    "\n",
    "In this exercise, we will create the design matrix $X$ using $d=25$ time lags. That is, $X$ should be a $T \\times d$ matrix. $d = 25$ (about 200 ms) is a choice we're making based on our prior knowledge of the temporal window that influences RGC responses. In practice, you might not know the right duration to use.\n",
    "\n",
    "The last entry in row `t` should correspond to the stimulus that was shown at time `t`, the entry to the left of it should contain the value that was show one time bin earlier, etc. Specifically, $X_{ij}$ will be the stimulus intensity at time $i + d - 1 - j$.\n",
    "\n",
    "Note that for the first few time bins, we have access to the recorded spike counts but not to the stimulus shown in the recent past. For simplicity we are going to assume that values of `stim` are 0 for the time lags prior to the first timepoint in the dataset. This is known as \"zero-padding\", so that the design matrix has the same number of rows as the response vectors in `spikes`.\n",
    "\n",
    "Your task is is to complete the function below to:\n",
    "\n",
    "  - make a zero-padded version of the stimulus\n",
    "  - initialize an empty design matrix with the correct shape\n",
    "  - **fill in each row of the design matrix, using the zero-padded version of the stimulus**\n",
    "\n",
    "To visualize your design matrix (and the corresponding vector of spike counts), we will plot a \"heatmap\", which encodes the numerical value in each position of the matrix as a color. The helper functions include some code to do this."
   ]
  },
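  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "As a small illustration of this layout (a sketch with made-up sizes, not the tutorial's actual data), take a hypothetical stimulus $s_0, s_1, s_2, s_3$ and $d = 3$ lags. With zero-padding, row $t$ holds the stimulus at times $t-2$, $t-1$, and $t$:\n",
    "\n",
    "$$\n",
    "X = \\begin{pmatrix}\n",
    "0 & 0 & s_0 \\\\\n",
    "0 & s_0 & s_1 \\\\\n",
    "s_0 & s_1 & s_2 \\\\\n",
    "s_1 & s_2 & s_3\n",
    "\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "Each row is a sliding window over the padded stimulus $(0, 0, s_0, s_1, s_2, s_3)$, shifted by one time bin per row. The rightmost column is therefore the stimulus itself, and each column to the left is delayed by one more bin."
   ]
  },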
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def make_design_matrix(stim, d=25):\n",
    "  \"\"\"Create time-lag design matrix from stimulus intensity vector.\n",
    "\n",
    "  Args:\n",
    "    stim (1D array): Stimulus intensity at each time point.\n",
    "    d (number): Number of time lags to use.\n",
    "\n",
    "  Returns\n",
    "    X (2D array): GLM design matrix with shape T, d\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  # Create version of stimulus vector with zeros before onset\n",
    "  padded_stim = np.concatenate([np.zeros(d - 1), stim])\n",
    "\n",
    "  #####################################################################\n",
    "  # Fill in missing code (...),\n",
    "  # then remove or comment the line below to test your function\n",
    "  raise NotImplementedError(\"Complete the make_design_matrix function\")\n",
    "  #####################################################################\n",
    "\n",
    "\n",
    "  # Construct a matrix where each row has the d frames of\n",
    "  # the stimulus proceeding and including timepoint t\n",
    "  T = len(...)  # Total number of timepoints (hint: number of stimulus frames)\n",
    "  X = np.zeros((T, d))\n",
    "  for t in range(T):\n",
    "      X[t] = ...\n",
    "\n",
    "  return X\n",
    "\n",
    "# Uncomment and run to test your function\n",
    "# X = make_design_matrix(stim)\n",
    "# plot_glm_matrices(X, spikes, nt=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def make_design_matrix(stim, d=25):\n",
    "  \"\"\"Create time-lag design matrix from stimulus intensity vector.\n",
    "\n",
    "  Args:\n",
    "    stim (1D array): Stimulus intensity at each time point.\n",
    "    d (number): Number of time lags to use.\n",
    "\n",
    "  Returns\n",
    "    X (2D array): GLM design matrix with shape T, d\n",
    "\n",
    "  \"\"\"\n",
    "  # Create version of stimulus vector with zeros before onset\n",
    "  padded_stim = np.concatenate([np.zeros(d - 1), stim])\n",
    "\n",
    "  # Construct a matrix where each row has the d frames of\n",
    "  # the stimulus proceeding and including timepoint t\n",
    "  T = len(stim)  # Total number of timepoints (hint: number of stimulus frames)\n",
    "  X = np.zeros((T, d))\n",
    "  for t in range(T):\n",
    "      X[t] = padded_stim[t:t + d]\n",
    "\n",
    "  return X\n",
    "\n",
    "X = make_design_matrix(stim)\n",
    "plot_glm_matrices(X, spikes, nt=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "## Section 1.2: Fit Linear-Gaussian regression model \n",
    "\n",
    "First, we will use the design matrix to compute the maximum likelihood estimate for a linear-Gaussian GLM (aka \"general linear model\"). The maximum likelihood estimate of $\\theta$ in this model can be solved analytically using the equation you learned about on Day 3:\n",
    "\n",
    "$$\\hat \\theta = (X^TX)^{-1}X^Ty$$\n",
    "\n",
    "Before we can apply this equation, we need to augment the design matrix to account for the mean of $y$, because the spike counts are all $\\geq 0$. We do this by adding a constant column of 1's to the design matrix, which will allow the model to learn an additive offset weight. We will refer to this additional weight as $b$ (for bias), although it is alternatively known as a \"DC term\" or \"intercept\"."
   ]
  },
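  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "The closed-form estimate can be recovered in a few lines. The sketch below assumes independent Gaussian noise with a fixed variance $\\sigma^2$ (the symbol $\\sigma$ is introduced here for illustration only). Under that assumption the log likelihood is\n",
    "\n",
    "$$\n",
    "\\log P(y \\mid X, \\theta) = -\\frac{1}{2\\sigma^2} \\lVert y - X\\theta \\rVert^2 + \\text{const}.\n",
    "$$\n",
    "\n",
    "Setting its gradient with respect to $\\theta$ to zero gives\n",
    "\n",
    "$$\n",
    "\\frac{1}{\\sigma^2} X^T (y - X\\theta) = 0\n",
    "\\;\\Longrightarrow\\; X^T X \\hat \\theta = X^T y\n",
    "\\;\\Longrightarrow\\; \\hat \\theta = (X^TX)^{-1}X^Ty.\n",
    "$$\n",
    "\n",
    "The last step requires $X^TX$ to be invertible. Note that $\\sigma^2$ drops out, so the estimate does not depend on the noise level."
   ]
  },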
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Build the full design matrix\n",
    "y = spikes\n",
    "constant = np.ones_like(y)\n",
    "X = np.column_stack([constant, make_design_matrix(stim)])\n",
    "\n",
    "# Get the MLE weights for the LG model\n",
    "theta = np.linalg.inv(X.T @ X) @ X.T @ y\n",
    "theta_lg = theta[1:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "Plot the resulting maximum likelihood filter estimate (just the 25-element weight vector $\\theta$ on the stimulus elements, not the DC term $b$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "plot_spike_filter(theta_lg, dt_stim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "---\n",
    "\n",
    "### Exercise 2: Predict spike counts with Linear-Gaussian model\n",
    "\n",
    "Now we are going to put these pieces together and write a function that outputs a predicted spike count for each timepoint using the stimulus information.\n",
    "\n",
    "Your steps should be:\n",
    "\n",
    "- Create the complete design matrix\n",
    "- Obtain the MLE weights ($\\hat \\theta$)\n",
    "- Compute $\\hat y = X\\hat \\theta$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def predict_spike_counts_lg(stim, spikes, d=25):\n",
    "  \"\"\"Compute a vector of predicted spike counts given the stimulus.\n",
    "\n",
    "  Args:\n",
    "    stim (1D array): Stimulus values at each timepoint\n",
    "    spikes (1D array): Spike counts measured at each timepoint\n",
    "    d (number): Number of time lags to use.\n",
    "\n",
    "  Returns:\n",
    "    yhat (1D array): Predicted spikes at each timepoint.\n",
    "\n",
    "  \"\"\"\n",
    "  ##########################################################################\n",
    "  # Fill in missing code (...) and then comment or remove the error to test\n",
    "  raise NotImplementedError(\"Complete the predict_spike_counts_lg function\")\n",
    "  ##########################################################################\n",
    "\n",
    "  # Create the design matrix\n",
    "  y = spikes\n",
    "  constant = ...\n",
    "  X = ...\n",
    "\n",
    "  # Get the MLE weights for the LG model\n",
    "  theta = ...\n",
    "\n",
    "  # Compute predicted spike counts\n",
    "  yhat = X @ theta\n",
    "  return yhat\n",
    "\n",
    "# Uncomment and run to test your function and plot prediction\n",
    "# predicted_counts = predict_spike_counts_lg(stim, spikes)\n",
    "# plot_spikes_with_prediction(spikes, predicted_counts, dt_stim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def predict_spike_counts_lg(stim, spikes, d=25):\n",
    "  \"\"\"Compute a vector of predicted spike counts given the stimulus.\n",
    "\n",
    "  Args:\n",
    "    stim (1D array): Stimulus values at each timepoint\n",
    "    spikes (1D array): Spike counts measured at each timepoint\n",
    "    d (number): Number of time lags to use.\n",
    "\n",
    "  Returns:\n",
    "    yhat (1D array): Predicted spikes at each timepoint.\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  # Create the design matrix\n",
    "  y = spikes\n",
    "  constant = np.ones_like(y)\n",
    "  X = np.column_stack([constant, make_design_matrix(stim)])\n",
    "\n",
    "  # Get the MLE weights for the LG model\n",
    "  theta = np.linalg.inv(X.T @ X) @ X.T @ y\n",
    "\n",
    "  # Compute predicted spike counts\n",
    "  yhat = X @ theta\n",
    "  return yhat\n",
    "\n",
    "predicted_counts = predict_spike_counts_lg(stim, spikes)\n",
    "plot_spikes_with_prediction(spikes, predicted_counts, dt_stim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "Is this a good model? The prediction line more-or-less follows the bumps in the spikes, but it never predicts as many spikes as are actually observed. And, more troublingly, it's predicting *negative* spikes for some time points.\n",
    "\n",
    "The Poisson GLM will help to address these failures.\n",
    "\n",
    "\n",
    "### Bonus challenge\n",
    "\n",
    "The \"spike-triggered average\" falls out as a subcase of the linear Gaussian GLM: $\\mathrm{STA} = X^T y \\,/\\, \\textrm{sum}(y)$, where $y$ is the vector of spike counts of the neuron. In the LG GLM, the term $(X^TX)^{-1}$ corrects for potential correlation between the regressors. Because the experiment that produced these data used a white noise stimulus, there are no such correlations. Therefore the two methods are equivalent. (How would you check the statement about no correlations?)"
   ]
  },
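  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "A sketch of why the two estimates agree, assuming the stimulus is zero-mean white noise with variance $\\sigma_s^2$ (a symbol introduced here for illustration) and ignoring the constant column of $X$. For large $T$,\n",
    "\n",
    "$$\n",
    "X^T X \\approx T \\sigma_s^2 I\n",
    "\\;\\Longrightarrow\\;\n",
    "\\hat \\theta \\approx \\frac{X^T y}{T \\sigma_s^2} = \\frac{\\textrm{sum}(y)}{T \\sigma_s^2} \\, \\mathrm{STA}.\n",
    "$$\n",
    "\n",
    "Under these assumptions the two filters have the same shape and differ only by a positive scale factor. The off-diagonal entries of $X^T X / T$ estimate the correlations between time lags, so inspecting that matrix is one way to check the white-noise statement."
   ]
  },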
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "# Section 2: Linear-Nonlinear-Poisson GLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "## Section 2.1: Nonlinear optimization with `scipy.optimize`\n",
    "\n",
    "Before diving into the Poisson GLM case, let us review the use and importance of convexity in optimization:\n",
    "- We have seen previously that in the Linear-Gaussian case, maximum likelihood  parameter estimate can be computed analytically. That is great because it only takes us a single line of code!\n",
    "- Unfortunately in general there is no analytical solution to our statistical estimation problems of interest. Instead, we need to apply a nonlinear optimization algorithm to find the parameter values that minimize some *objective function*. This can be extremely tedious because there is no general way to check whether we have found *the optimal solution* or if we are just stuck in some local minimum.\n",
    "- Somewhere in between theses two extremes, the spetial case of convex objective function is of great practical importance. Indeed, such optimization problems can be solved very reliably (and usually quite rapidly too!) using some standard software.\n",
    "\n",
    "Notes:\n",
    "- a function is convex if and only if its curve lies below any chord joining two of its points\n",
    "- to learn more about optimization, you can consult the book of Stephen Boyd and Lieven Vandenberghe [Convex Optimization](https://web.stanford.edu/~boyd/cvxbook/)."
   ]
  },
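  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "The chord condition in the first note can be written as an inequality. This is the standard definition, stated here with the illustrative symbols $x_1$, $x_2$, and $\\alpha$: a function $f$ is convex if, for all $x_1$, $x_2$ and all $\\alpha \\in [0, 1]$,\n",
    "\n",
    "$$\n",
    "f(\\alpha x_1 + (1 - \\alpha) x_2) \\leq \\alpha f(x_1) + (1 - \\alpha) f(x_2).\n",
    "$$\n",
    "\n",
    "As a worked check, for $f(x) = x^2$ the gap between the chord and the curve is\n",
    "\n",
    "$$\n",
    "\\alpha x_1^2 + (1 - \\alpha) x_2^2 - (\\alpha x_1 + (1 - \\alpha) x_2)^2 = \\alpha (1 - \\alpha) (x_1 - x_2)^2 \\geq 0,\n",
    "$$\n",
    "\n",
    "so $x^2$ is convex."
   ]
  },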
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "Here we will use the `scipy.optimize` module, it contains a function called [`minimize`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html) that provides a generic interface to a large number of optimization algorithms. This function expects as argument an objective function and an \"initial guess\" for the parameter values. It then returns a dictionary that includes the minimum function value, the parameters that give this minimum, and other information.\n",
    "\n",
    "Let's see how this works with a simple example. We want to minimize the function $f(x) = x^2$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "f = np.square\n",
    "\n",
    "res = minimize(f, x0=2)\n",
    "print(\n",
    "  f\"Minimum value: {res['fun']:.4g}\",\n",
    "  f\"at x = {res['x']}\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "When minimizing a $f(x) = x^2$, we get a minimum value of $f(x) \\approx 0$ when $x \\approx 0$. The algorithm doesn't return exactly $0$, because it stops when it gets \"close enough\" to a minimum. You can change the `tol` parameter to control how it defines \"close enough\".\n",
    "\n",
    "A point about the code bears emphasis. The first argument to `minimize` is not a number or a string but a *function*. Here, we used `np.square`. Take a moment to make sure you understand what's going on, because it's a bit unusual, and it will be important for the exercise you're going to do in a moment.\n",
    "\n",
    "In this example, we started at $x_0 = 2$. Let's try different values for the starting point:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "start_points = -1, 1.5\n",
    "\n",
    "xx = np.linspace(-2, 2, 100)\n",
    "plt.plot(xx, f(xx), color=\".2\")\n",
    "plt.xlabel(\"$x$\")\n",
    "plt.ylabel(\"$f(x)$\")\n",
    "\n",
    "for i, x0 in enumerate(start_points):\n",
    "  res = minimize(f, x0)\n",
    "  plt.plot(x0, f(x0), \"o\", color=f\"C{i}\", ms=10, label=f\"Start {i}\")\n",
    "  plt.plot(res[\"x\"].item(), res[\"fun\"], \"x\", c=f\"C{i}\", ms=10, mew=2, label=f\"End {i}\")\n",
    "  plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "The runs started at different points (the dots), but they each ended up at roughly the same place (the cross): $f(x_\\textrm{final}) \\approx 0$. Let's see what happens if we use a different function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "g = lambda x: x / 5 + np.cos(x)\n",
    "start_points = -.5, 1.5\n",
    "\n",
    "xx = np.linspace(-4, 4, 100)\n",
    "plt.plot(xx, g(xx), color=\".2\")\n",
    "plt.xlabel(\"$x$\")\n",
    "plt.ylabel(\"$f(x)$\")\n",
    "\n",
    "for i, x0 in enumerate(start_points):\n",
    "  res = minimize(g, x0)\n",
    "  plt.plot(x0, g(x0), \"o\", color=f\"C{i}\", ms=10, label=f\"Start {i}\")\n",
    "  plt.plot(res[\"x\"].item(), res[\"fun\"], \"x\", color=f\"C{i}\", ms=10, mew=2, label=f\"End {i}\")\n",
    "  plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "Unlike $f(x) = x^2$, $g(x) = \\frac{x}{5} + \\cos(x)$ is not *convex*. We see that the final position of the minimization algorithm depends on the starting point, which adds a layer of comlpexity to such problems."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "### Exercise 3: Fitting the Poisson GLM and prediction spikes\n",
    "\n",
    "In this exercise, we will use [`scipy.optimize.minimize`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html) to compute maximum likelihood estimates for the filter weights in the Poissson GLM model with an exponential nonlinearity (LNP: Linear-Nonlinear-Poisson).\n",
    "\n",
    "In practice, this will involve filling out two functions.\n",
    "\n",
    "- The first should be an *objective function* that takes a design matrix, a spike count vector, and a vector of parameters. It should return a negative log likelihood.\n",
    "- The second function should take `stim` and `spikes`, build the design matrix and then use `minimize` internally, and return the MLE parameters.\n",
    "\n",
    "What should the objective function look like? We want it to return the negative log likelihood: $-\\log P(y \\mid X, \\theta).$\n",
    "\n",
    "In the Poisson GLM,\n",
    "\n",
    "$$\n",
    "\\log P(\\mathbf{y} \\mid X, \\theta) = \\sum_t \\log P(y_t \\mid \\mathbf{x_t},\\theta),\n",
    "$$\n",
    "\n",
    "where\n",
    "\n",
    "$$ P(y_t \\mid \\mathbf{x_t}, \\theta) = \\frac{\\lambda_t^{y_t}\\exp(-\\lambda_t)}{y_t!} \\text{, with rate } \\lambda_t = \\exp(\\mathbf{x_t}^{\\top} \\theta).$$\n",
    "\n",
    "Now, taking the log likelihood for all the data we obtain:\n",
    "$\\log P(\\mathbf{y} \\mid X, \\theta) = \\sum_t( y_t \\log(\\lambda_t) - \\lambda_t - \\log(y_t !)).$\n",
    "\n",
    "Because we are going to minimize the negative log likelihood with respct to the parameters $\\theta$, we can ignore the last term that does not depend on $\\theta$. For faster implementation, let us rewrite this in matrix notation:\n",
    "\n",
    "$$\\mathbf{y}^T \\log(\\mathbf{\\lambda}) - \\mathbf{1}^T \\mathbf{\\lambda} \\text{, with  rate } \\mathbf{\\lambda} = \\exp(X^{\\top} \\theta)$$\n",
    "\n",
    "Finally, don't forget to add the minus sign for your function to return the negative log likelihood."
   ]
  },
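  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "To connect this objective to the convexity discussion in Section 2.1, the following sketch substitutes the exponential rate into the log likelihood. It writes $X\\theta$ for the vector whose $t$-th entry is $\\mathbf{x_t}^{\\top} \\theta$, and uses the symbol $L$ for the negative log likelihood (a name introduced here for illustration). Since $\\log(\\lambda_t) = \\mathbf{x_t}^{\\top} \\theta$,\n",
    "\n",
    "$$\n",
    "L(\\theta) = -\\mathbf{y}^T X \\theta + \\mathbf{1}^T \\exp(X \\theta) + \\text{const}.\n",
    "$$\n",
    "\n",
    "Differentiating twice with respect to $\\theta$ gives\n",
    "\n",
    "$$\n",
    "\\nabla L(\\theta) = X^T (\\mathbf{\\lambda} - \\mathbf{y}), \\qquad \\nabla^2 L(\\theta) = X^T \\mathrm{diag}(\\mathbf{\\lambda}) X.\n",
    "$$\n",
    "\n",
    "Because every $\\lambda_t > 0$, the Hessian is positive semidefinite. With the exponential nonlinearity, $L$ is therefore convex in $\\theta$, and any local minimum is also a global one."
   ]
  },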
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def neg_log_lik_lnp(theta, X, y):\n",
    "  \"\"\"Return -loglike for the Poisson GLM model.\n",
    "\n",
    "  Args:\n",
    "    theta (1D array): Parameter vector.\n",
    "    X (2D array): Full design matrix.\n",
    "    y (1D array): Data values.\n",
    "\n",
    "  Returns:\n",
    "    number: Negative log likelihood.\n",
    "\n",
    "  \"\"\"\n",
    "  #####################################################################\n",
    "  # Fill in missing code (...), then remove the error\n",
    "  raise NotImplementedError(\"Complete the neg_log_lik_lnp function\")\n",
    "  #####################################################################\n",
    "\n",
    "  # Compute the Poisson log likeliood\n",
    "  rate = np.exp(X @ theta)\n",
    "  log_lik = y @ ... - ...\n",
    "\n",
    "  return ...\n",
    "\n",
    "\n",
    "def fit_lnp(stim, spikes, d=25):\n",
    "  \"\"\"Obtain MLE parameters for the Poisson GLM.\n",
    "\n",
    "  Args:\n",
    "    stim (1D array): Stimulus values at each timepoint\n",
    "    spikes (1D array): Spike counts measured at each timepoint\n",
    "    d (number): Number of time lags to use.\n",
    "\n",
    "  Returns:\n",
    "    1D array: MLE parameters\n",
    "\n",
    "  \"\"\"\n",
    "  #####################################################################\n",
    "  # Fill in missing code (...), then remove the error\n",
    "  raise NotImplementedError(\"Complete the fit_lnp function\")\n",
    "  #####################################################################\n",
    "\n",
    "  # Build the design matrix\n",
    "  y = spikes\n",
    "  constant = np.ones_like(y)\n",
    "  X = np.column_stack([constant, make_design_matrix(stim)])\n",
    "\n",
    "  # Use a random vector of weights to start (mean 0, sd .2)\n",
    "  x0 = np.random.normal(0, .2, d + 1)\n",
    "\n",
    "  # Find parameters that minmize the negative log likelihood function\n",
    "  res = minimize(..., args=(X, y))\n",
    "\n",
    "  return ...\n",
    "\n",
    "\n",
    "# Uncomment and run to test your function\n",
    "# theta_lnp = fit_lnp(stim, spikes)\n",
    "# plot_spike_filter(theta_lg[1:], dt_stim, color=\".5\", label=\"LG\")\n",
    "# plot_spike_filter(theta_lnp[1:], dt_stim, label=\"LNP\")\n",
    "# plt.legend(loc=\"upper left\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def neg_log_lik_lnp(theta, X, y):\n",
    "  \"\"\"Return -loglike for the Poisson GLM model.\n",
    "\n",
    "  Args:\n",
    "    theta (1D array): Parameter vector.\n",
    "    X (2D array): Full design matrix.\n",
    "    y (1D array): Data values.\n",
    "\n",
    "  Returns:\n",
    "    number: Negative log likelihood.\n",
    "\n",
    "  \"\"\"\n",
    "  # Compute the Poisson log likeliood\n",
    "  rate = np.exp(X @ theta)\n",
    "  log_lik = y @ np.log(rate) - rate.sum()\n",
    "  return -log_lik\n",
    "\n",
    "\n",
    "def fit_lnp(stim, spikes, d=25):\n",
    "  \"\"\"Obtain MLE parameters for the Poisson GLM.\n",
    "\n",
    "  Args:\n",
    "    stim (1D array): Stimulus values at each timepoint\n",
    "    spikes (1D array): Spike counts measured at each timepoint\n",
    "    d (number): Number of time lags to use.\n",
    "\n",
    "  Returns:\n",
    "    1D array: MLE parameters\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  # Build the design matrix\n",
    "  y = spikes\n",
    "  constant = np.ones_like(y)\n",
    "  X = np.column_stack([constant, make_design_matrix(stim)])\n",
    "\n",
    "  # Use a random vector of weights to start (mean 0, sd .2)\n",
    "  x0 = np.random.normal(0, .2, d + 1)\n",
    "\n",
    "  # Find parameters that minmize the negative log likelihood function\n",
    "  res = minimize(neg_log_lik_lnp, x0, args=(X, y))\n",
    "\n",
    "  return res[\"x\"]\n",
    "\n",
    "\n",
    "theta_lnp = fit_lnp(stim, spikes)\n",
    "plot_spike_filter(theta_lg[1:], dt_stim, color=\".5\", label=\"LG\")\n",
    "plot_spike_filter(theta_lnp[1:], dt_stim, label=\"LNP\")\n",
    "plt.legend(loc=\"upper left\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "Plotting the LG and LNP weights together, we see that they are broadly similar, but the LNP weights are generally larger. What does that mean for the model's ability to *predict* spikes? To see that, let's finish the exercise by filling out the `predict_spike_counts_lnp` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def predict_spike_counts_lnp(stim, spikes, theta=None, d=25):\n",
    "  \"\"\"Compute a vector of predicted spike counts given the stimulus.\n",
    "\n",
    "  Args:\n",
    "    stim (1D array): Stimulus values at each timepoint\n",
    "    spikes (1D array): Spike counts measured at each timepoint\n",
    "    theta (1D array): Filter weights; estimated if not provided.\n",
    "    d (number): Number of time lags to use.\n",
    "\n",
    "  Returns:\n",
    "    yhat (1D array): Predicted spikes at each timepoint.\n",
    "\n",
    "  \"\"\"\n",
    "  ###########################################################################\n",
    "  # Fill in missing code (...) and then remove the error to test\n",
    "  raise NotImplementedError(\"Complete the predict_spike_counts_lnp function\")\n",
    "  ###########################################################################\n",
    "\n",
    "  y = spikes\n",
    "  constant = np.ones_like(spikes)\n",
    "  X = np.column_stack([constant, make_design_matrix(stim)])\n",
    "  if theta is None:  # Allow pre-cached weights, as fitting is slow\n",
    "    theta = fit_lnp(X, y, d)\n",
    "\n",
    "  yhat = ...\n",
    "  return yhat\n",
    "\n",
    "# Uncomment and run to test predict_spike_counts_lnp\n",
    "# yhat = predict_spike_counts_lnp(stim, spikes, theta_lnp)\n",
    "# plot_spikes_with_prediction(spikes, yhat, dt_stim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab_type": "code",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def predict_spike_counts_lnp(stim, spikes, theta=None, d=25):\n",
    "  \"\"\"Compute a vector of predicted spike counts given the stimulus.\n",
    "\n",
    "  Args:\n",
    "    stim (1D array): Stimulus values at each timepoint\n",
    "    spikes (1D array): Spike counts measured at each timepoint\n",
    "    theta (1D array): Filter weights; estimated if not provided.\n",
    "    d (number): Number of time lags to use.\n",
    "\n",
    "  Returns:\n",
    "    yhat (1D array): Predicted spikes at each timepoint.\n",
    "\n",
    "  \"\"\"\n",
    "  y = spikes\n",
    "  constant = np.ones_like(spikes)\n",
    "  X = np.column_stack([constant, make_design_matrix(stim)])\n",
    "  if theta is None:  # Allow pre-cached weights, as fitting is slow\n",
    "    theta = fit_lnp(X, y, d)\n",
    "\n",
    "  yhat = np.exp(X @ theta)\n",
    "  return yhat\n",
    "\n",
    "yhat = predict_spike_counts_lnp(stim, spikes, theta_lnp)\n",
    "plot_spikes_with_prediction(spikes, yhat, dt_stim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "We see that the LNP model does a better job of fitting the actual spiking data. Importantly, it never predicts negative spikes!\n",
    "\n",
    "*Bonus:* Our statement that the LNP model \"does a better job\" is qualitative and based mostly on the visual appearance of the plot. But how would you make this a quantitative statement?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {}
   },
   "source": [
    "---\n",
    "# Summary\n",
    "\n",
    "In this first tutorial, we used two different models to learn something about how retinal ganglion cells respond to a flickering white noise stimulus. We learned how to construct a design matrix that we could pass to different GLMs, and we found that the Linear-Nonlinear-Poisson (LNP) model allowed us to predict spike rates better than a simple Linear-Gaussian (LG) model.\n",
    "\n",
    "In the next tutorial, we'll extend these ideas further. We'll meet yet another GLM — logistic regression — and we'll learn how to ensure good model performance even when the number of parameters `d` is large compared to the number of data points `N`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Notation\n",
    "\n",
    "\\begin{align}\n",
    "y &\\quad \\text{measurement or response, here: spike count}\\\\\n",
    "T &\\quad \\text{number of time points}\\\\\n",
    "d &\\quad \\text{input dimensionality}\\\\\n",
    "\\mathbf{X} &\\quad \\text{design matrix, dimensions: } T \\times d\\\\\n",
    "\\theta &\\quad \\text{parameter}\\\\\n",
    "\\hat \\theta &\\quad \\text{estimated parameter}\\\\\n",
    "\\hat y &\\quad \\text{estimated response}\\\\\n",
    "P(\\mathbf{y} \\mid \\mathbf{X}, \\theta) &\\quad \\text{probability of observing response } y \\text{ given design matrix } \\mathbf{X} \\text{ and parameters } \\theta \\\\\n",
    "\\mathrm{STA} &\\quad \\text{spike-triggered average}\\\\\n",
    "b &\\quad \\text{bias weight, intercept}\\\\\n",
    "\\end{align}"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W4_Tutorial2",
   "provenance": [],
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "toc-autonumbering": true
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
